{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "3599d0f8",
   "metadata": {},
   "source": [
    "# Tutorial16: DIFFPOOL"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "4d1eca5c",
   "metadata": {},
   "outputs": [],
   "source": [
    "!pip install -q torch-scatter -f https://data.pyg.org/whl/torch-1.10.0+cu113.html\n",
    "!pip install -q torch-sparse -f https://data.pyg.org/whl/torch-1.10.0+cu113.html\n",
    "!pip install -q git+https://github.com/pyg-team/pytorch_geometric.git"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "02d41129",
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "2b0f2059",
   "metadata": {},
   "source": [
    "Below are shown the computation to obtain the nodes features matrix and adjacency matrix for the first hierarchical step. \n",
    "\n",
    "Initial graph: \n",
    "```x_0   = 50 x 32\n",
    "adj_0  = 50 x 50```"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "4c7bdbf9",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Node features matrix\n",
    "x_0 = torch.rand(50, 32)\n",
    "adj_0 = torch.rand(50,50).round().long()\n",
    "identity = torch.eye(50)\n",
    "adj_0 = adj_0 + identity"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "753e9e5c",
   "metadata": {},
   "source": [
    "Set the number of clusters we want to obtain at step 1"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "5b8647c0",
   "metadata": {},
   "outputs": [],
   "source": [
    "n_clusters_0 = 50\n",
    "n_clusters_1 = 5"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "c7c607e4",
   "metadata": {},
   "source": [
    "Initialize the weights of GNN_emb and GNN_pool, we use just 1 conv layer"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "1c420b7c",
   "metadata": {},
   "outputs": [],
   "source": [
    "w_gnn_emb = torch.rand(32, 16)\n",
    "w_gnn_pool = torch.rand(32, n_clusters_1)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "3f410b12",
   "metadata": {},
   "source": [
    "<img src=\"img1.png\" width=300px>\n",
    "<img src=\"img2.png\" width=400px>"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "9264b7d7",
   "metadata": {},
   "outputs": [],
   "source": [
    "z_0 = torch.relu(adj_0 @ x_0 @ w_gnn_emb)\n",
    "s_0 = torch.softmax(torch.relu(adj_0 @ x_0 @ w_gnn_pool), dim=1)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "aa5c0c75",
   "metadata": {},
   "source": [
    "<img src=\"img3.png\" width=200px>\n",
    "<img src=\"img4.png\" width=200px>"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "92465df5",
   "metadata": {},
   "outputs": [],
   "source": [
    "x_1 = s_0.t() @ z_0\n",
    "adj_1 = s_0.t() @ adj_0 @ s_0"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "6a8a7596",
   "metadata": {},
   "outputs": [],
   "source": [
    "print(x_1.shape)\n",
    "print(adj_1.shape)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "9c695ca4",
   "metadata": {},
   "outputs": [],
   "source": [
    "import os.path as osp\n",
    "from math import ceil\n",
    "\n",
    "import torch\n",
    "import torch.nn.functional as F\n",
    "from torch_geometric.datasets import TUDataset\n",
    "import torch_geometric.transforms as T\n",
    "from torch_geometric.data import DenseDataLoader\n",
    "from torch_geometric.nn import DenseGCNConv as GCNConv, dense_diff_pool\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "3eb52795",
   "metadata": {},
   "outputs": [],
   "source": [
    "max_nodes = 150\n",
    "\n",
    "\n",
    "class MyFilter(object):\n",
    "    def __call__(self, data):\n",
    "        return data.num_nodes <= max_nodes\n",
    "\n",
    "\n",
    "dataset = TUDataset('data', name='PROTEINS', transform=T.ToDense(max_nodes),\n",
    "                    pre_filter=MyFilter())\n",
    "dataset = dataset.shuffle()\n",
    "n = (len(dataset) + 9) // 10\n",
    "test_dataset = dataset[:n]\n",
    "val_dataset = dataset[n:2 * n]\n",
    "train_dataset = dataset[2 * n:]\n",
    "test_loader = DenseDataLoader(test_dataset, batch_size=32)\n",
    "val_loader = DenseDataLoader(val_dataset, batch_size=32)\n",
    "train_loader = DenseDataLoader(train_dataset, batch_size=32)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "186a211b",
   "metadata": {},
   "outputs": [],
   "source": [
    "class GNN(torch.nn.Module):\n",
    "    def __init__(self, in_channels, hidden_channels, out_channels,\n",
    "                 normalize=False, lin=True):\n",
    "        super(GNN, self).__init__()\n",
    "        \n",
    "        self.convs = torch.nn.ModuleList()\n",
    "        self.bns = torch.nn.ModuleList()\n",
    "        \n",
    "        self.convs.append(GCNConv(in_channels, hidden_channels, normalize))\n",
    "        self.bns.append(torch.nn.BatchNorm1d(hidden_channels))\n",
    "        \n",
    "        self.convs.append(GCNConv(hidden_channels, hidden_channels, normalize))\n",
    "        self.bns.append(torch.nn.BatchNorm1d(hidden_channels))\n",
    "        \n",
    "        self.convs.append(GCNConv(hidden_channels, out_channels, normalize))\n",
    "        self.bns.append(torch.nn.BatchNorm1d(out_channels))\n",
    "\n",
    "\n",
    "    def forward(self, x, adj, mask=None):\n",
    "        batch_size, num_nodes, in_channels = x.size()\n",
    "        \n",
    "        for step in range(len(self.convs)):\n",
    "            x = self.bns[step](F.relu(self.convs[step](x, adj, mask)))\n",
    "        \n",
    "\n",
    "        return x\n",
    "\n",
    "\n",
    "class DiffPool(torch.nn.Module):\n",
    "    def __init__(self):\n",
    "        super(DiffPool, self).__init__()\n",
    "\n",
    "        num_nodes = ceil(0.25 * max_nodes)\n",
    "        self.gnn1_pool = GNN(dataset.num_features, 64, num_nodes)\n",
    "        self.gnn1_embed = GNN(dataset.num_features, 64, 64)\n",
    "\n",
    "        num_nodes = ceil(0.25 * num_nodes)\n",
    "        self.gnn2_pool = GNN(64, 64, num_nodes)\n",
    "        self.gnn2_embed = GNN(64, 64, 64, lin=False)\n",
    "\n",
    "        self.gnn3_embed = GNN(64, 64, 64, lin=False)\n",
    "\n",
    "        self.lin1 = torch.nn.Linear(64, 64)\n",
    "        self.lin2 = torch.nn.Linear(64, dataset.num_classes)\n",
    "\n",
    "    def forward(self, x, adj, mask=None):\n",
    "        s = self.gnn1_pool(x, adj, mask)\n",
    "        x = self.gnn1_embed(x, adj, mask)\n",
    "\n",
    "        x, adj, l1, e1 = dense_diff_pool(x, adj, s, mask)\n",
    "        #x_1 = s_0.t() @ z_0\n",
    "        #adj_1 = s_0.t() @ adj_0 @ s_0\n",
    "        \n",
    "        s = self.gnn2_pool(x, adj)\n",
    "        x = self.gnn2_embed(x, adj)\n",
    "\n",
    "        x, adj, l2, e2 = dense_diff_pool(x, adj, s)\n",
    "\n",
    "        x = self.gnn3_embed(x, adj)\n",
    "\n",
    "        x = x.mean(dim=1)\n",
    "        x = F.relu(self.lin1(x))\n",
    "        x = self.lin2(x)\n",
    "        return F.log_softmax(x, dim=-1), l1 + l2, e1 + e2\n",
    "\n",
    "\n",
    "\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a0e89716",
   "metadata": {},
   "outputs": [],
   "source": [
    "device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')\n",
    "model = DiffPool().to(device)\n",
    "optimizer = torch.optim.Adam(model.parameters(), lr=0.001)\n",
    "\n",
    "\n",
    "def train(epoch):\n",
    "    model.train()\n",
    "    loss_all = 0\n",
    "\n",
    "    for data in train_loader:\n",
    "        data = data.to(device)\n",
    "        optimizer.zero_grad()\n",
    "        output, _, _ = model(data.x, data.adj, data.mask)\n",
    "        loss = F.nll_loss(output, data.y.view(-1))\n",
    "        loss.backward()\n",
    "        loss_all += data.y.size(0) * loss.item()\n",
    "        optimizer.step()\n",
    "    return loss_all / len(train_dataset)\n",
    "\n",
    "\n",
    "@torch.no_grad()\n",
    "def test(loader):\n",
    "    model.eval()\n",
    "    correct = 0\n",
    "\n",
    "    for data in loader:\n",
    "        data = data.to(device)\n",
    "        pred = model(data.x, data.adj, data.mask)[0].max(dim=1)[1]\n",
    "        correct += pred.eq(data.y.view(-1)).sum().item()\n",
    "    return correct / len(loader.dataset)\n",
    "\n",
    "\n",
    "best_val_acc = test_acc = 0\n",
    "for epoch in range(1, 151):\n",
    "    train_loss = train(epoch)\n",
    "    val_acc = test(val_loader)\n",
    "    if val_acc > best_val_acc:\n",
    "        test_acc = test(test_loader)\n",
    "        best_val_acc = val_acc\n",
    "    print(f'Epoch: {epoch:03d}, Train Loss: {train_loss:.4f}, '\n",
    "          f'Val Acc: {val_acc:.4f}, Test Acc: {test_acc:.4f}')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "749420ff",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "df48fe17",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.7"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
